import sigpy as sp
import sigpy.mri as mr
import sigpy.plot as pl
import scipy.io as sio
import numpy as np
import matplotlib.pyplot as plt
from scipy import signal
from utils import undersample
from utils import ss_combine
from skimage import metrics
from sense import sense

# Load the 8-channel brain data.  Both arrays are stored coil-last in the
# .mat file; move the coil axis to the front -> (8, 160, 220).
data = sio.loadmat('brain_8ch.mat')
coil_first = (2, 0, 1)
im = np.transpose(data['im'], coil_first)     # coil images
map = np.transpose(data['map'], coil_first)   # coil sensitivity maps

# Plot the coil images and the coil sensitivities.
im_p = pl.ImagePlot(im, z=0, hide_axes=True)
map_p = pl.ImagePlot(map, z=0, hide_axes=True)

# Retrospective undersampling of the coil images.  The GRAPPA code below
# treats this as R = 4 along axis 1 (kx) -- presumably what
# undersample(im, 1, 4) does; confirm against utils.undersample.
imu = undersample(im, 1, 4)
imu_p = pl.ImagePlot(imu, z=0, hide_axes=True)

# 2-D FFT of every coil over the two spatial axes.
spatial_axes = (1, 2)
ksp = sp.fft(im, axes=spatial_axes)    # fully sampled k-space (8, 160, 220)
kspu = sp.fft(imu, axes=spatial_axes)  # undersampled k-space, same shape

def grappa_kernel_R4(ksp, cal_width):
    """Fit GRAPPA kernels for 4x undersampling along kx (axis 1).

    With R = 4 every missing kx line lies 1, 2 or 3 lines after an acquired
    line, so one kernel is fitted per offset.  Kernel ``k`` predicts a target
    line from the two acquired lines at kx offsets ``-(k + 1)`` and ``3 - k``
    and, for each of them, the 5 ky neighbours ``-2..2`` in every coil
    (a 2 x 5 source pattern per coil).

    The weights are a least-squares fit on the ``cal_width`` central kx lines
    of ``ksp``.  Only target positions whose complete source neighbourhood
    lies inside the calibration region are used as training equations, so no
    equation mixes real target data with artificial zero padding.

    Parameters
    ----------
    ksp : ndarray, shape (nc, nkx, nky)
        Fully sampled k-space; only the central ``cal_width`` kx lines are read.
    cal_width : int
        Number of central kx lines used for calibration (5 <= cal_width <= nkx).

    Returns
    -------
    list of ndarray
        Three kernels, each of shape (nc, nc, 5, 5) indexed
        (target coil, source coil, kx, ky).  The fitted weights sit on the
        first and last kx rows (rows 1-3 are zero).  Each kernel is flipped on
        both spatial axes, so ``scipy.signal.convolve2d(..., mode='same')``
        applies it as a correlation.  The estimate produced for kernel ``k``
        must then be shifted by ``k - 1`` rows along kx to land on its target.

    Raises
    ------
    ValueError
        If ``cal_width`` is outside [5, nkx] or nky < 5.
    """
    nc, nkx, nky = ksp.shape
    if not 5 <= cal_width <= nkx:
        raise ValueError(f"cal_width must be in [5, {nkx}], got {cal_width}")
    if nky < 5:
        raise ValueError(f"need at least 5 ky samples, got {nky}")

    # Central calibration lines, rearranged to (cal_width, nky, nc).
    x0 = (nkx - cal_width) // 2
    kspcal = np.transpose(ksp[:, x0:x0 + cal_width, :], (1, 2, 0))

    ypos = np.arange(-2, 3)            # ky neighbours of the target
    ypad = len(ypos) // 2
    # Target ky positions whose 5 neighbours are all inside the data.
    ty = np.arange(ypad, nky - ypad)

    kernels = []
    for k in range(3):
        xpos = np.array([-k - 1, 3 - k])   # acquired kx lines around the target
        # Target kx positions whose two source lines are both inside the
        # calibration region (no zero-padded sources in the fit).
        tx = np.arange(-xpos[0], cal_width - xpos[1])
        n_eq = len(tx) * len(ty)

        # Source matrix (n_tx, n_ty, 2, 5, nc) -> (n_eq, 2 * 5 * nc); the
        # column order (kx source, ky source, coil) matches the reshape of X.
        rows = tx[:, None, None, None] + xpos[None, None, :, None]
        cols = ty[None, :, None, None] + ypos[None, None, None, :]
        A = kspcal[rows, cols].reshape(n_eq, -1)
        # Target values (n_eq, nc).
        B = kspcal[tx[:, None], ty[None, :]].reshape(n_eq, nc)

        X = np.linalg.lstsq(A, B, rcond=None)[0]

        # Put the two fitted kx rows at the ends of a 5-row kernel spanning
        # kx offsets -(k+1) .. 3-k; the rows in between (incl. the target) are 0.
        weights = X.reshape(len(xpos), len(ypos), nc, nc)
        kernel = np.zeros((5, len(ypos), nc, nc), dtype=weights.dtype)
        kernel[0], kernel[-1] = weights[0], weights[1]
        kernel = np.transpose(kernel, (3, 2, 0, 1))   # (coil_tgt, coil_src, kx, ky)
        kernels.append(np.flip(kernel, (-2, -1)))
    return kernels

# Apply the R = 4 GRAPPA kernels (calibrated on the central 20 kx lines) to
# the undersampled k-space.  Kernel k fills the kx lines k + 1 after each
# acquired line.  Assuming kspu is zero on the missing lines (TODO confirm
# against utils.undersample), each kernel's output is non-zero only on its
# own target lines away from the edges, so the estimates are simply added
# to the acquired data.
grappa_kernels = grappa_kernel_R4(ksp, 20)
ksp_grappa = np.copy(kspu)
ncoils = kspu.shape[0]
for k, kernel in enumerate(grappa_kernels):
    ksp_grappa_k = np.zeros(kspu.shape, dtype=complex)
    for coil_i in range(ncoils):
        for coil_j in range(ncoils):
            ksp_grappa_k[coil_i] += signal.convolve2d(
                kspu[coil_j], kernel[coil_i, coil_j], mode='same')
    # convolve2d centres the 5-row kernel on its middle row, while the target
    # lies k + 1 rows after the first source row: shift all coils once by
    # k - 1 rows, then discard exactly the rows np.roll wrapped across the
    # kx edge.
    shift = k - 1
    ksp_grappa_k = np.roll(ksp_grappa_k, shift, axis=1)
    if shift > 0:
        ksp_grappa_k[:, :shift, :] = 0
    elif shift < 0:
        ksp_grappa_k[:, shift:, :] = 0
    ksp_grappa += ksp_grappa_k

# Reconstruct the coil images, display the sum-of-squares combined GRAPPA
# k-space, and save the sum-of-squares combined GRAPPA image.
im_grappa = sp.ifft(ksp_grappa, axes=(1, 2))
im_grappa_ss = ss_combine(im_grappa)

plt.figure()
plt.imshow(np.abs(ss_combine(ksp_grappa)))
plt.show()

# Explicit new figure so savefig writes the image even on interactive
# backends where plt.show() does not close the previous figure.
plt.figure()
plt.imshow(np.abs(im_grappa_ss), cmap='gray')   # reuse, don't recombine
plt.savefig("grappa4.png")

# pSNR of the R = 4 GRAPPA reconstruction against the fully sampled
# sum-of-squares image.  data_range (the "peak") is taken from the reference
# image, as is conventional, so the score does not depend on the
# reconstruction's own dynamic range.
im_gt = ss_combine(im)
gt_mag = np.abs(im_gt)
data_range = gt_mag.max() - gt_mag.min()
pSNR = metrics.peak_signal_noise_ratio(gt_mag, np.abs(im_grappa_ss), data_range=data_range)
print("pSNR for GRAPPA (R = 4) is " + str(pSNR) + " dB")